import numpy as np
import pandas as pd
import os
import copy
import cvxpy as cp
import gurobipy as gp
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold
from sample_split import split_sample_cv


def CV_frontier(df, Y_column, G_column, X_columns, obj_list,
                n_splits=5, seed=42, lmd=0.001, solver=cp.GUROBI,
                verbose=True, standardize=True):
    """Run the cross-validated LP once per objective and collect the losses

    Parameters
    ----------
    df : pandas.DataFrame
        whole dataset
    Y_column : str
        outcome variable
    G_column : str
        group variable
    X_columns : list of str
        covariates
    obj_list : sequence
        objectives; each element is passed as `obj` to `CV_frontier_main`
    n_splits, seed, lmd, solver, verbose, standardize
        forwarded unchanged to `CV_frontier_main`

    Returns
    -------
    res : dict
        'obj_list' : the `obj_list` that was passed in
        'e_b_list' : the 'e_b' entry returned by `CV_frontier_main` for each
                     objective, in order
        'e_r_list' : the 'e_r' entry returned by `CV_frontier_main` for each
                     objective, in order
    """
    total = len(obj_list)
    print(f'number of objectives: {total}')

    # Settings that are identical for every objective.
    shared_kwargs = dict(n_splits=n_splits, seed=seed, lmd=lmd,
                         solver=solver, verbose=verbose,
                         standardize=standardize)

    blue_losses = []
    red_losses = []
    for step, objective in enumerate(obj_list, start=1):
        # Progress is always printed here, independently of `verbose`.
        print(f'iteration: {step}/{total}')
        cv_result = CV_frontier_main(df, Y_column, G_column, X_columns,
                                     objective, **shared_kwargs)
        blue_losses.append(cv_result['e_b'])
        red_losses.append(cv_result['e_r'])

    return {
        'obj_list': obj_list,
        'e_b_list': blue_losses,
        'e_r_list': red_losses,
    }


def CV_frontier_main(df, Y_column, G_column, X_columns, obj=np.array([1.0,0.0]),
                     n_splits=5, seed=42, lmd=0.001, solver=cp.GUROBI,
                     verbose=True, standardize=True):
    """Cross-validation for classification with LP

    For every fold, fits the LP classifier (`lp_frontier`) on the training
    part and evaluates the misclassification rate of each group on the test
    part (`est_group_loss_lp`).

    Parameters
    ----------
    df : pandas.DataFrame
        whole dataset
    Y_column : str
        outcome variable
    G_column : str
        group variable
    X_columns : list of str
        covariates
    obj : np.array, optional
        specifies the objective function of LP
        [a, b]: a*e_b + b*e_r, by default np.array([1.0,0.0])
    n_splits : int, optional
        number of splits, by default 5
    seed : int, optional
        for randomization, by default 42
    lmd: float, 0.001 by default
        parameter for regularization
    solver: cp.SCS, cp.ECOS, cp.OSQP, cp.GUROBI, by default cp.GUROBI
        solver for LP
    verbose : bool, optional
        print out the results, by default True
    standardize : bool, optional
        standardize the covariates, by default True
        (forwarded to `split_sample_cv`)

    Returns
    -------
    res : dict
        'obj'          : the objective that was passed in
        'e_b', 'e_r'   : mean over folds of the per-fold test losses
        'beta_lp_list' : fitted coefficients, one entry per fold
        'e_b_list', 'e_r_list' : per-fold test losses

    Raises
    ------
    RuntimeError
        if `lp_frontier` returns no solution (None) for some fold
    """
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    # Work on a re-indexed copy; the caller's frame is left untouched.
    df = df.reset_index(drop=True)
    e_b_list = []
    e_r_list = []
    beta_lp_list = []

    for fold, (train_index, test_index) in enumerate(kf.split(df)):
        if verbose:
            print(f'iteration: {fold}')

        df_X_test_blue, df_y_test_blue, df_X_test_red, df_y_test_red, train =\
            split_sample_cv(G_column, X_columns, Y_column, df,
                            train_index, test_index,
                            standardize=standardize)

        # LP: design matrix with the intercept appended as the last column,
        # which is the layout `lp_frontier` documents.
        df_X_train = train[X_columns].copy()
        df_X_train['constant'] = 1
        X_train = df_X_train.values
        y_train = train[Y_column].copy().values
        g_train = train[G_column].copy().values

        # TODO (kept from the original author): THE FOLLOWING SHOULD BE MODIFIED
        beta_lp = lp_frontier(X_train, y_train, g_train, obj=obj,
                              lmd=lmd, solver=solver, verbose=verbose)
        if beta_lp is None:
            # `lp_frontier` signals solver failure with None; evaluating it
            # would only fail later with an unrelated-looking matmul error.
            raise RuntimeError(
                f'lp_frontier returned no solution in fold {fold} '
                f'(obj={obj}, solver={solver})'
            )

        # estimate points on the frontier; the test design matrices need the
        # same trailing intercept column as the training matrix.
        df_X_test_blue['constant'] = 1
        df_X_test_red['constant'] = 1
        res_temp = est_group_loss_lp(beta_lp, df_X_test_blue, df_y_test_blue,
                                     df_X_test_red, df_y_test_red)
        e_b_list.append(res_temp['e_b'])
        e_r_list.append(res_temp['e_r'])
        beta_lp_list.append(beta_lp)

    res = {}
    res['obj'] = obj
    res['e_b'] = np.mean(e_b_list)
    res['e_r'] = np.mean(e_r_list)
    res['beta_lp_list'] = beta_lp_list
    res['e_b_list'] = e_b_list
    res['e_r_list'] = e_r_list

    return res


def _hinge_constraint(slack, y_pm, X, beta, flip):
    """Build the hinge-loss constraint on `slack` for one group.

    `y_pm` holds the labels recoded as 2*y - 1. With `flip` False the
    constraint is ``slack >= 1 - y_pm * (X @ beta)``; with `flip` True the
    sign of the margin is reversed, ``slack >= y_pm * (X @ beta) + 1``.
    """
    margin = cp.multiply(y_pm, X @ beta)
    if flip:
        return slack >= margin + 1
    return slack >= 1 - margin


def lp_frontier(X_train_std, y_train, g_train, obj=np.array([1.0, 0.0]),
                lmd=0.001, solver=cp.GUROBI, verbose=False):
    """
    Solve LP with surrogate loss (hinge loss approximating 0-1 loss) with
    L1-regularization
    Use misclassification loss for both accuracy and fairness

    The problem solved is

        minimize    |obj[0]| * e_b + |obj[1]| * e_r
        subject to  e_g = mean(w_g) + lmd * (sum(beta_plus) + sum(beta_minus))
                    w_g >= 0,  beta_plus >= 0,  beta_minus >= 0
                    w_g >= 1 - s_g * (2*y_g - 1) * (X_g @ beta)

    for g in {b, r}, with beta = beta_plus - beta_minus and s_g = +1 when
    obj's entry for that group is >= 0 (after snapping near-zero to 0) and
    s_g = -1 when it is negative.

    Parameters
    ----------
    X_train_std : numpy array
        standardized training dataset.
        last column should be 1 (a constant term)
    y_train : numpy array
        outcome variable (recoded internally as 2*y - 1, so presumably
        coded 0/1 -- confirm against the caller)
    g_train : numpy array
        group variable; rows with g != 0 form the "b" group and rows with
        g != 1 form the "r" group (so a 0/1 coding is assumed)
    obj : numpy array, by default np.array([1.0, 0.0])
        specifies the objective function of LP
    lmd: float, 0.001 by default
        parameter for regularization
    solver : cvxpy solver name, by default cp.GUROBI
        when cp.ECOS is given and does not reach status 'optimal',
        cp.SCS and then cp.OSQP are tried in turn
    verbose : bool, by default False
        forwarded to the solver; also prints the problem status

    Returns
    -------
    beta_lp : numpy array of shape (n_cov, 1), or None
        beta_plus - beta_minus; None when no solution is available

    Raises
    ------
    ValueError
        if either group has no observation in the training data
    """
    n_cov = X_train_std.shape[1]

    I_b_train = np.nonzero(g_train)[0]
    I_r_train = np.nonzero(g_train - 1)[0]
    n_b = I_b_train.shape[0]
    n_r = I_r_train.shape[0]
    if n_b == 0 or n_r == 0:
        # The group losses are averages over n_b / n_r observations, so an
        # empty group makes the LP ill-defined.
        raise ValueError(
            f'both groups need at least one training observation '
            f'(n_b={n_b}, n_r={n_r})'
        )

    y_train = y_train.reshape((y_train.shape[0], 1))
    y_pm_b = 2*y_train[I_b_train] - 1
    y_pm_r = 2*y_train[I_r_train] - 1
    X_b_train = X_train_std[I_b_train, :]
    X_r_train = X_train_std[I_r_train, :]

    # Snap numerically-zero weights to exactly 0 so they count as
    # non-negative when choosing the hinge direction.
    alpha_b = 0 if np.isclose(0, obj[0]) else obj[0]
    alpha_r = 0 if np.isclose(0, obj[1]) else obj[1]
    flip_b = alpha_b < 0
    flip_r = alpha_r < 0
    # The objective always uses the magnitudes; a negative weight is
    # expressed through the reversed hinge constraint instead.
    alpha_b = np.abs(alpha_b)
    alpha_r = np.abs(alpha_r)

    # define variables for LP
    beta_plus = cp.Variable((n_cov, 1))
    beta_minus = cp.Variable((n_cov, 1))
    w_b = cp.Variable((n_b, 1))
    w_r = cp.Variable((n_r, 1))
    e_r = cp.Variable()
    e_b = cp.Variable()

    beta = beta_plus - beta_minus
    # L1 norm of beta, written linearly via the non-negative split.
    l1_penalty = lmd*cp.sum(beta_plus[:, 0]) + lmd*cp.sum(beta_minus[:, 0])

    constr = [
        w_b >= 0, w_r >= 0, beta_plus >= 0, beta_minus >= 0,
        e_b == cp.sum(w_b[:, 0])/n_b + l1_penalty,
        e_r == cp.sum(w_r[:, 0])/n_r + l1_penalty,
    ]

    # Same relative order of the two hinge constraints as the original
    # four-branch formulation: "b" first unless the "b" weight is negative.
    groups = [
        (w_b, y_pm_b, X_b_train, flip_b),
        (w_r, y_pm_r, X_r_train, flip_r),
    ]
    if flip_b:
        groups.reverse()
    constr += [
        _hinge_constraint(slack, y_pm, X, beta, flip)
        for slack, y_pm, X, flip in groups
    ]

    prob = cp.Problem(cp.Minimize(alpha_b*e_b + alpha_r*e_r), constr)

    ## solve
    prob.solve(solver=solver, verbose=verbose)
    if verbose:
        print(prob.status)

    # Check if the solver solved the problem
    if solver == cp.GUROBI and prob.status != 'optimal':
        print('gurobi cannot solve the problem')

    if solver == cp.ECOS and prob.status != 'optimal':
        print('ECOS cannot solve the problem')
        for fallback, label in ((cp.SCS, 'SCS'), (cp.OSQP, 'OSQP')):
            prob.solve(solver=fallback, verbose=verbose)
            if prob.status == 'optimal':
                break
            print(f'{label} cannot solve the problem')
        else:
            # every fallback failed as well
            return None

    # retrieve values; they are None when the last solve produced no
    # solution, in which case report failure the same way as above.
    if beta_plus.value is None or beta_minus.value is None:
        return None

    return beta_plus.value - beta_minus.value


def est_group_loss_lp(beta, df_X_test_blue, y_test_blue,
                      df_X_test_red, y_test_red):
    """Estimate group losses for classification

    Parameters
    ----------
    beta : numpy array
        coefficients of the linear classifier; multiplied from the right
        onto the `.values` of each test design matrix
    df_X_test_blue, df_X_test_red : pandas.DataFrame
        test covariates of each group
    y_test_blue, y_test_red : array-like
        true labels of each group

    Returns
    -------
    res : dict
        'e_b' : 1 - accuracy on the blue test set
        'e_r' : 1 - accuracy on the red test set
    """
    # Predicted class labels (not probabilities): 1 where the linear score
    # is strictly positive, 0 otherwise.
    predicted = [
        (frame.values @ beta > 0).ravel().astype(int)
        for frame in (df_X_test_blue, df_X_test_red)
    ]
    observed = (y_test_blue, y_test_red)

    error_rates = [
        1 - accuracy_score(truth, guess)
        for truth, guess in zip(observed, predicted)
    ]
    return dict(zip(('e_b', 'e_r'), error_rates))


###############################################################################
# `lp_regularized`, `CV_classification_lp` are not used above
###############################################################################

# def lp_regularized(X_train_std, y_train, g_train, obj='b', lmd=0.001,
#                    solver=cp.GUROBI, verbose=False):
#     '''
#     Solve LP with surrogate loss (hinge loss approximating 0-1 loss) with
#     L1-regularization
#     Use misclassification loss for both accuracy and fairness
    
#     Parameters
#     ----------
#     X_train_std : numpy array
#         standardized training dataset.
#         last column should be 1 (a constant term)
#     y_train : numpy array
#         outcome variable
#     g_train : numpy array
#         group variable
#     obj : string
#         specifies the objective function of LP
#         'b': e_b
#         'r': e_r
#         'f': |e_b - e_r|
#     lmd: float, 0.001 by default
#         parameter for regularization
#     '''
#     n_cov = X_train_std.shape[1]
    
#     I_b_train = np.nonzero(g_train)[0]
#     I_r_train = np.nonzero(g_train - 1)[0]
    
#     y_train = y_train.reshape((y_train.shape[0], 1))
#     y_b_train = y_train[I_b_train]
#     y_b_train = y_b_train.reshape((y_b_train.shape[0], 1))
#     y_r_train = y_train[I_r_train]
#     y_r_train = y_r_train.reshape((y_r_train.shape[0], 1))
#     n_b = y_b_train.shape[0]
#     n_r = y_r_train.shape[0]

#     ## define variables for LP
#     beta_plus = cp.Variable((n_cov,1))
#     beta_minus = cp.Variable((n_cov,1))
#     w = cp.Variable((n_b+n_r,1))
#     e_r = cp.Variable()
#     e_b = cp.Variable()
#     constr = [
#                 e_b == cp.sum(w[I_b_train, 0])/n_b +\
#                 lmd*cp.sum(beta_plus[:,0]) + lmd*cp.sum(beta_minus[:,0]),
#                 e_r == cp.sum(w[I_r_train, 0])/n_r +\
#                 lmd*cp.sum(beta_plus[:,0]) + lmd*cp.sum(beta_minus[:,0]),
#                 w >= 1 - cp.multiply((2*y_train -1), X_train_std @
#                                     (beta_plus - beta_minus)),
#                 w >= 0,
#                 beta_plus >= 0,
#                 beta_minus >= 0,
#                 ]
    
#     ## set objective
#     if obj == 'b':
#         prob = cp.Problem(cp.Minimize(e_b), constr)
#     elif obj == 'r':
#         prob = cp.Problem(cp.Minimize(e_r), constr)
#     elif obj == 'f':
#         prob = cp.Problem(cp.Minimize(cp.abs(e_b - e_r)), constr)
    
#     ## solve
#     prob.solve(solver=solver, verbose=verbose)
#     if verbose:
#         print(prob.status)
    
#     # Check if the solver solved the problem
#     if solver == cp.GUROBI:
#         if prob.status != 'optimal':
#             print('gurobi cannot solve the problem')
    
#     if solver == cp.ECOS:
#         if prob.status != 'optimal':
#             print('ECOS cannot solve the problem')
#             prob.solve(solver=cp.SCS, verbose=verbose)
#             if prob.status != 'optimal':
#                 print('SCS cannot solve the problem')
#                 prob.solve(solver=cp.OSQP, verbose=verbose)
#                 if prob.status != 'optimal':
#                     print('OSQP cannot solve the problem')
#                     return None
    
#     # retrieve values
#     beta_lp = beta_plus.value - beta_minus.value
    
#     return beta_lp


# def CV_classification_lp(df, Y_column, G_column, X_columns, n_splits=5, seed=42,
#                          lmd=0.001, solver=cp.GUROBI, verbose=True,
#                          standardize=True):
#     """Cross-validation for classification with LP

#     Parameters
#     ----------
#     df : pandas.DataFrame
#         whole dataset
#     G_column : str
#         group variable
#     X_columns : list of str
#         covariates
#     Y_column : str
#         outcome variable
#     n_splits : int, optional
#         number of splits, by default 5
#     seed : int, optional
#         for randomization, by default 42
#     obj : string, optional
#         specifies the objective function of LP
#         'b': e_b
#         'r': e_r
#         'f': |e_b - e_f|, by default 'b'
#     lmd: float, 0.001 by default
#         parameter for regularization
#     solver: cp.SCS, cp.ECOS, cp.OSQP, cp.GUROBI, by default cp.GUROBI
#         solver for LP
#     clf : sklearn.base.BaseEstimator, optional
#         classifier from scikit-learn, by default None

#     Returns
#     -------
#     res : dict
#         a dictionary of the results
#     """
#     kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
#     df = df.reset_index(drop=True, inplace=True)
#     e_b_blue_list = []
#     e_r_blue_list = []
#     e_b_red_list = []
#     e_r_red_list = []
#     f_b_list = []
#     f_r_list = []
#     beta_lp_b_list = []
#     beta_lp_r_list = []
#     beta_lp_f_list = []
    
#     counter = 0
#     for train_index, test_index in kf.split(df):
#         print(f'iteration: {counter}')
#         counter += 1
        
#         df_X_test_blue, df_y_test_blue, df_X_test_red, df_y_test_red, train =\
#             split_sample_cv(G_column, X_columns, Y_column, df,
#                             train_index, test_index,
#                             standardize=standardize)

#         # LP
#         df_X_train = train[X_columns]
#         df_X_train['constant'] = 1
#         X_train = df_X_train.values
#         y_train = train[Y_column]
#         y_train = y_train.values
#         g_train = train[G_column]
#         g_train = g_train.values
        
#         ## black-optimal
#         beta_lp_b = lp_regularized(X_train, y_train, g_train, obj='b',
#                                    lmd=lmd, solver=solver, verbose=False)
        
#         ## white-optimal
#         beta_lp_r = lp_regularized(X_train, y_train, g_train, obj='r',
#                                    lmd=lmd, solver=solver, verbose=False)
        
#         ## fairness-optimal
#         beta_lp_f = lp_regularized(X_train, y_train, g_train, obj='f',
#                                    lmd=lmd, solver=solver, verbose=False)
        
#         # delete temporary objects for memory efficiency
#         del df_X_train, X_train, y_train, g_train, train
        
#         # blue
#         df_X_test_blue['constant'] = 1
#         df_X_test_red['constant'] = 1
#         res_blue =\
#             est_group_loss_lp(beta_lp_b, df_X_test_blue, df_y_test_blue,
#                               df_X_test_red, df_y_test_red)
#         e_b = res_blue['e_b']
#         e_r = res_blue['e_r']
#         e_b_blue_list.append(e_b)
#         e_r_blue_list.append(e_r)
#         beta_lp_b_list.append(beta_lp_b)

#         # red
#         res_red = est_group_loss_lp(beta_lp_r, df_X_test_blue, df_y_test_blue,
#                                     df_X_test_red, df_y_test_red)
#         e_b = res_red['e_b']
#         e_r = res_red['e_r']
#         e_b_red_list.append(e_b)
#         e_r_red_list.append(e_r)
#         beta_lp_r_list.append(beta_lp_r)
        
#         # fairness
#         res_fair = est_group_loss_lp(beta_lp_f, df_X_test_blue, df_y_test_blue,
#                                      df_X_test_red, df_y_test_red)
#         f_b = res_fair['e_b']
#         f_r = res_fair['e_r']
#         f_b_list.append(f_b)
#         f_r_list.append(f_r)
#         beta_lp_f_list.append(beta_lp_f)
        
#         # delete temporary objects for memory efficiency
#         del df_X_test_blue, df_X_test_red, df_y_test_blue, df_y_test_red

#     e_b_blue = np.mean(e_b_blue_list)
#     e_r_blue = np.mean(e_r_blue_list)
#     e_b_red = np.mean(e_b_red_list)
#     e_r_red = np.mean(e_r_red_list)
#     f_b = np.mean(f_b_list)
#     f_r = np.mean(f_r_list)
#     res_balance = eval_balance(e_b_blue, e_r_blue, e_b_red, e_r_red)

#     # print out the results
#     if verbose:
#         print()
#         print(f'blue: {G_column}=1')
#         print(f'[blue] e_b={e_b_blue.round(3)}, e_r={e_r_blue.round(3)}, fairness={np.abs(e_b_blue - e_r_blue).round(3)}')
#         print(f'[red] e_b={e_b_red.round(3)}, e_r={e_r_red.round(3)}, fairness={np.abs(e_b_red - e_r_red).round(3)}')
#         print(f'[fair] e_b={f_b.round(3)}, e_r={f_r.round(3)}, fairness={np.abs(f_b - f_r).round(3)}')
#         print(res_balance)
    
#     res = {'e_b_blue': e_b_blue_list, 'e_r_blue': e_r_blue_list,
#            'e_b_red': e_b_red_list, 'e_r_red': e_r_red_list,
#            'f_b': f_b_list, 'f_r': f_r_list,
#            'beta_lp_b': beta_lp_b_list, 'beta_lp_r': beta_lp_r_list,
#            'beta_lp_f': beta_lp_f_list,
#            'balance': res_balance}
#     return res
